import numpy as np
from numpy import *
import scipy
from scipy.optimize import minimize 
from scipy.integrate import nquad
import matplotlib.pyplot as plt
import matplotlib
import time
import sys
from numba import jit
from scipy.stats import ortho_group
import h5py
import matplotlib.patheffects as pe
from scipy.io import savemat
from scipy.optimize import curve_fit
#########################################
# Pass 96 as the command-line argument (sys.argv[1]) for agreement with the experimental work arXiv:2109.12631 (corresponds to D_0=0).
################################################################################################################################

# Select the non-interactive Agg backend: this script only saves figures to
# disk (plt.savefig) and never shows a window.
matplotlib.use('Agg')
# Print whole arrays instead of summarising them with '...'.
np.set_printoptions(threshold=np.inf)
# ---------------------------------------------------------------------------
# Simulation parameters and moire-lattice geometry
# ---------------------------------------------------------------------------

# Per-k-point diagonal matrices of eigenvalues; populated once the eigenvalue
# file has been loaded.
KineticTerm = []

# k-point mesh: nptsd controls the density, npts is the resulting count.
nptsd = 10
npts = 3*nptsd**2 - 3*nptsd + 1 - 2*nptsd - (nptsd - 2)

# Presumably the graphene carbon-carbon distance in metres (1.42 Angstrom).
a = 1.42*10**(-10)
# Appears unused by the code below.
N = 1

# Dirac momentum of a single graphene layer.
kD = 4*np.pi/(3*a*np.sqrt(3))
# Twist angle: 1.53 degrees, converted to radians.
theta = 1.53*np.pi/180
# Moire momentum scale: side length of the hexagonal moire Brillouin zone.
ktheta = 2*kD*np.sin(np.abs(theta)/2)

# Three vectors of length ktheta, 120 degrees apart (q1 + q2 + q3 = 0).
q1x, q1y = 0, -ktheta
q2x, q2y = ktheta*np.sqrt(3)/2, ktheta/2
q3x, q3y = -ktheta*np.sqrt(3)/2, ktheta/2

# Moire reciprocal-lattice vectors b1 and b2.
b1x, b1y = np.sqrt(3)/2*ktheta, 3/2*ktheta
b2x, b2y = np.sqrt(3)/2*ktheta, -3/2*ktheta

# Momentum cutoff: points are kept if they lie within |nshells*b2 + q1| of
# the origin, plus a small margin. nactiveshells appears unused below.
nshells = 3
nactiveshells = 1
radius = np.sqrt((nshells*b2x+q1x)**2+(nshells*b2y+q1y)**2)+.001*ktheta


def _shell_momenta(coefficients):
    """Return x and y component arrays of n1*b1 + n2*b2, one entry per
    (n1, n2) pair in *coefficients*, in the given order."""
    xs = np.array([n1*b1x + n2*b2x for n1, n2 in coefficients])
    ys = np.array([n1*b1y + n2*b2y for n1, n2 in coefficients])
    return xs, ys


# Reciprocal-lattice vectors labelled shell 1 (6 vectors) and shell 2 (12).
shell1momentax, shell1momentay = _shell_momenta(
    [(1, 0), (0, 1), (-1, 0), (0, -1), (1, 1), (-1, -1)])
shell2momentax, shell2momentay = _shell_momenta(
    [(2, 0), (0, 2), (-2, 0), (0, -2), (2, 2), (-2, -2),
     (2, 1), (-2, -1), (1, 2), (-1, -2), (1, -1), (-1, 1)])

# Layer label for every momentum point kept inside `radius`, in basis order:
# all layer-2 points first, then layer 1, then layer 3. MakeDOS uses the
# position i of each label to index the state vector (entries 2*i, 2*i+1,
# ...), so this ordering must not change.
#
# Points are i*b1 + j*b2 + shift with -10 <= i, j < 10. Layers 1 and 3 share
# the same shifted lattice (shift q1 + q2), so their counts are identical and
# are computed only once.
_moire_grid = [(i, j) for i in range(-10, 10) for j in range(-10, 10)]
_n_layer2 = sum(
    1
    for i, j in _moire_grid
    if np.sqrt((i*b1x+j*b2x+q1x)**2+(i*b1y+j*b2y+q1y)**2)<radius
)
_n_layer13 = sum(
    1
    for i, j in _moire_grid
    if np.sqrt((i*b1x+j*b2x+q1x+q2x)**2+(i*b1y+j*b2y+q1y+q2y)**2)<radius
)
qvalslayer = [2]*_n_layer2 + [1]*_n_layer13 + [3]*_n_layer13
del _moire_grid, _n_layer2, _n_layer13

# Two basis entries per momentum point (MakeDOS reads indices 2*i and 2*i+1).
nbands=np.size(qvalslayer)*2
nbandsminus=nbands-2
nbandsplus=nbands+2
# Size of the active window; always 4 (= nbandsplus - nbandsminus).
nactive=nbandsplus-nbandsminus

# Index of the input data set: the first command-line argument selects which
# "..._%d" files are read below.
if len(sys.argv) < 2:
    sys.exit('usage: %s RUN_INDEX' % sys.argv[0])
run_index = int(sys.argv[1])


def _load_npz_array(path):
    """Return the array stored under the default key ``arr_0`` in *path*.

    ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps the
    underlying file open until it is closed; the ``with`` block closes it
    once the array has been read into memory.
    """
    with np.load(path) as archive:
        return archive['arr_0']


# NOTE: these absolute paths point at the original author's cluster
# directories; replace them with your own paths when running elsewhere.
HtotalEigVecs = _load_npz_array('/n/holylfs/LABS/sachdev_lab/vectors2/HtotalEigVecsrepo_%d.npz'%(run_index,))
HtotalEigVecsShell1 = _load_npz_array('/n/holylfs/LABS/sachdev_lab/vectors2/HtotalEigVecsShell1repo_%d.npz'%(run_index,))
HtotalEigVecsShell2 = _load_npz_array('/n/holylfs/LABS/sachdev_lab/vectors2/HtotalEigVecsShell2repo_%d.npz'%(run_index,))
HtotalEigVals = _load_npz_array('/n/holylfs/LABS/sachdev_lab/vectors2/HtotalEigValsrepo_%d.npz'%(run_index,))

# One diagonal matrix of eigenvalues per k-point.
KineticTerm.extend(np.diag(HtotalEigVals[i]) for i in range(npts))

kxvals=np.load('/n/home11/mchristos/TTG/kxvalsrepo.npy')
kyvals=np.load('/n/home11/mchristos/TTG/kyvalsrepo.npy')

# Per-k-point matrices that are diagonalised before building the DOS.
Htotal=np.load('/n/home11/mchristos/TTG/epsvary/eps10noPHtest2HtotalLowerBandsrepo_%d.npy'%(run_index,))

def eigsystemh(H):
    """Diagonalise the square matrix ``H`` with ``np.linalg.eig``.

    Returns ``(eigenvalues, eigenvectors)`` with the eigenvalues ordered by
    ``np.argsort`` and the eigenvector columns permuted to match, so that
    ``eigenvectors[:, n]`` belongs to ``eigenvalues[n]``.

    NOTE(review): the "h" suffix suggests ``H`` is Hermitian. If so,
    ``np.linalg.eigh`` would guarantee real eigenvalues and an orthonormal
    eigenbasis even for degenerate levels, which ``eig`` does not.
    """
    raw_vals, raw_vecs = np.linalg.eig(H)
    order = np.argsort(raw_vals)
    return raw_vals[order], raw_vecs[:, order]



# Energy scan for the density of states; units presumably meV (the plot's
# x-axis is labelled "Energy (meV)").
rangeE, nenergy = 80, 1600          # scan [-rangeE, rangeE] with nenergy samples
energyStep = np.linspace(-rangeE, rangeE, num=nenergy)
# MakeDOS skips states whose eigenvalue is farther than this from the
# sample energy.
energyspacing = 160

# Complex accumulator for the DOS at each sample energy, filled in place.
DOSFinal = np.zeros(nenergy, dtype=complex)
# Turn the layer labels into an ndarray (presumably so the numba-compiled
# MakeDOS can index it) and record the number of momentum points.
qvalslayer = np.asarray(qvalslayer)
qsize = qvalslayer.size
@jit(nopython=True)
def MakeDOS(Hvals,Hvecs,DOS):
    """Accumulate the Lorentzian-broadened layer-1 DOS into ``DOS``.

    For every k-point ``pts`` and state ``a`` (``nactive*4`` per k-point),
    the eigenvector ``Hvecs[pts][:, a]`` is multiplied by
    ``HtotalEigVecs[pts]`` to give a vector ``v``. Its weight on the
    momentum points labelled layer 1 (``qvalslayer == 1``) is

        sum_{s=0,1} | sum_{i in layer 1} (v[2i+s] + v[nbands+2i+s]) |**2

    and each state adds ``-weight/npts/(2*pi)**2 * Im 1/(E - E_a + 5j)`` to
    every sample energy ``E = energyStep[k]`` lying within ``energyspacing``
    of its eigenvalue ``E_a = Hvals[pts][a]``.

    Parameters
    ----------
    Hvals : array of shape (npts, 4*nactive), eigenvalues per k-point.
    Hvecs : array of shape (npts, 4*nactive, 4*nactive); column ``a`` of
        ``Hvecs[pts]`` is the eigenvector for ``Hvals[pts][a]``.
    DOS : array of length nenergy, updated in place.

    Returns
    -------
    The same ``DOS`` array.

    Also reads the module globals nenergy, npts, nactive, energyStep,
    energyspacing, HtotalEigVecs, qsize, qvalslayer and nbands (numba treats
    them as compile-time constants).
    """
    prefactor = 1/npts*1/(2*np.pi)**2
    nstates = nactive*4

    # The layer-1 weight of a state does not depend on the sample energy, so
    # compute it once per (k-point, state) instead of once per energy.
    weights = np.zeros((npts, nstates))
    for pts in range(npts):
        for a in range(nstates):
            vecunproj5 = HtotalEigVecs[pts].dot(Hvecs[pts][:, a])
            # The double sum over layer-1 points i, j of
            # conj(v[2i+s]) v[2j+s] + conj(v[2i+s]) v[nbands+2j+s]
            # + conj(v[nbands+2i+s]) v[nbands+2j+s] + conj(v[nbands+2i+s]) v[2j+s]
            # factorises into |sum_i (v[2i+s] + v[nbands+2i+s])|**2 per
            # component s, turning an O(qsize**2) loop into an O(qsize) one.
            weight = 0.0
            for s in range(2):
                amp = 0j
                for i in range(qsize):
                    if qvalslayer[i] != 1:
                        continue
                    amp += vecunproj5[2*i+s] + vecunproj5[nbands+2*i+s]
                weight += amp.real*amp.real + amp.imag*amp.imag
            weights[pts, a] = weight

    for k in range(nenergy):
        print('k: ',k)
        for pts in range(npts):
            for a in range(nstates):
                # Only states near the sample energy contribute.
                if np.abs(energyStep[k]-Hvals[pts][a])>energyspacing:
                    continue
                # Lorentzian broadening with eta = 5 (the "+5j").
                DOS[k] = DOS[k] - prefactor*weights[pts, a]*np.imag(1/(energyStep[k]-Hvals[pts][a]+5j))
    return DOS


print('starting numba')

# Sorted eigenpairs of Htotal[i] for every k-point. Storage is complex
# because eigsystemh uses np.linalg.eig, which may return complex results.
eigvals=np.full((npts,4*nactive),0.+0.j)
eigvecs=np.full((npts,4*nactive,4*nactive),0.+0.j)
for i in range(npts):
    # Diagonalise once and unpack both results; calling eigsystemh twice
    # would repeat the identical diagonalisation.
    eigvals[i], eigvecs[i] = eigsystemh(Htotal[i])
# MakeDOS fills DOSFinal in place (it also returns it).
MakeDOS(eigvals,eigvecs,DOSFinal)




plt.title('$w_1$=125 meV, $w_0$=92 meV, $v_0$=1.0$\\times 10^{6}$ m/s',fontsize=14)

# Plot the real part explicitly: passing the complex DOSFinal to plt.plot
# discards the imaginary part anyway, but emits a ComplexWarning. The raw
# string keeps the LaTeX backslashes without invalid-escape warnings.
plt.plot(energyStep, DOSFinal.real, label=r'L.B. $\epsilon=10$, $\eta=5$', color='firebrick')
plt.yticks([])
plt.ylabel('LDOS (normalized)',fontsize=12)
plt.xlim(-rangeE,rangeE)
plt.xlabel('Energy (meV)',fontsize=12)
plt.legend(loc='upper left', edgecolor="white")
plt.savefig('DOS.pdf')


# Save the DOS array (np.save appends the '.npy' extension).
np.save('DOSFinalcompare',DOSFinal)


# Save in MATLAB-friendly format: the LDOS values and the energy axis go to
# separate .mat files.
mdic1 = {"a": DOSFinal, "label": "LDOSyaxis"}
mdic2= {"a": energyStep, "label": "Energyxaxis"}

savemat("LDOStest1.mat", mdic1)
savemat("Energytest1.mat", mdic2)



